{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "TtlIRiNXWlQ0"
      },
      "source": [
        "# Waste identification with instance segmentation in TensorFlow"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "ohoMgYgXWsIO"
      },
      "source": [
        "Welcome to the Instance Segmentation Colab! This notebook will take you through the steps of running an \"out-of-the-box\" Mask RCNN Instance Segmentation model on images."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "8PKG9z4VYPEs"
      },
      "source": [
        "To finish this task, a proper path for the saved models and a single image need to be provided. The path to the labels on which the models are trained is in the waste_identification_ml directory inside the TensorFlow Model Garden repository. The label files are inferred automatically for both models."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "j7yl9CqgYWvS"
      },
      "source": [
        "## Imports and Setup"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "HhMR363skKaY"
      },
      "outputs": [],
      "source": [
        "!pip install -q tf_keras\n",
        "!pip install -q webcolors==1.13"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "ELUFMVDDAopS"
      },
      "outputs": [],
      "source": [
        "from six.moves.urllib.request import urlopen\n",
        "from six import BytesIO\n",
        "from PIL import Image\n",
        "import tensorflow as tf\n",
        "import numpy as np\n",
        "import sys\n",
        "import matplotlib.pyplot as plt\n",
        "import matplotlib\n",
        "import logging\n",
        "import pandas as pd\n",
        "\n",
        "logging.disable(logging.WARNING)\n",
        "\n",
        "%matplotlib inline"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "imIpwZgv_3dE"
      },
      "source": [
        "Run the following cell to import utility functions that will be needed for pre-processing, post-processing and color detection.\n",
        "\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "_77YK3a_BCg_"
      },
      "source": [
        "To visualize the images with the proper detected boxes and segmentation masks, we will use the TensorFlow Object Detection API. To install it we will clone the repo.\n",
        "\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "qhk_NujKO0mb"
      },
      "outputs": [],
      "source": [
        "# Clone the tensorflow models repository.\n",
        "!git clone --depth 1 https://github.com/tensorflow/models 2\u003e/dev/null"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "fBAdlmHKO3AV"
      },
      "outputs": [],
      "source": [
        "# Download the script to pull instance segmentation model weights from the TF Model Garden repo.\n",
        "url = (\n",
        "    \"https://raw.githubusercontent.com/\"\n",
        "    \"tensorflow/models/master/\"\n",
        "    \"official/projects/waste_identification_ml/\"\n",
        "    \"model_inference/download_and_unzip_models.py\"\n",
        ")\n",
        "\n",
        "!wget -q {url}"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "o1dYyG55BtWb"
      },
      "outputs": [],
      "source": [
        "sys.path.append('models/research/')\n",
        "from object_detection.utils import visualization_utils as viz_utils"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "cbcEDJHAB65J"
      },
      "outputs": [],
      "source": [
        "sys.path.append('models/official/projects/waste_identification_ml/model_inference/')\n",
        "import preprocessing\n",
        "import postprocessing\n",
        "import color_and_property_extractor\n",
        "import labels"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "nq2DNpXQ_0-n"
      },
      "source": [
        "## Utilities"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "GO488S78_2GJ"
      },
      "outputs": [],
      "source": [
        "def load_image_into_numpy_array(path):\n",
        "  \"\"\"Loads an image from a URL or file path into a batched numpy array.\n",
        "\n",
        "  The image is converted to RGB first, so grayscale, palette and RGBA inputs\n",
        "  (e.g. PNGs with an alpha channel) all yield exactly three channels instead\n",
        "  of failing in the reshape.\n",
        "\n",
        "  Args:\n",
        "    path: An 'http://' / 'https://' URL, or a path readable by\n",
        "      `tf.io.gfile.GFile`.\n",
        "\n",
        "  Returns:\n",
        "    A writable uint8 numpy array with shape (1, height, width, 3).\n",
        "  \"\"\"\n",
        "  if path.startswith(('http://', 'https://')):\n",
        "    # Close the HTTP response as soon as the payload has been read.\n",
        "    with urlopen(path) as response:\n",
        "      image_data = response.read()\n",
        "  else:\n",
        "    # Use a context manager so the file handle is not leaked.\n",
        "    with tf.io.gfile.GFile(path, 'rb') as image_file:\n",
        "      image_data = image_file.read()\n",
        "\n",
        "  # Decode from memory and force three channels.\n",
        "  image = Image.open(BytesIO(image_data)).convert('RGB')\n",
        "\n",
        "  # `np.array` copies the pixels, so callers may modify the result in place.\n",
        "  # PIL exposes the image as (height, width, channels); add the batch axis.\n",
        "  return np.array(image, dtype=np.uint8)[np.newaxis, ...]\n",
        "\n",
        "\n",
        "def load_model(model_handle):\n",
        "  \"\"\"Loads a SavedModel and exposes its default serving signature.\n",
        "\n",
        "  Args:\n",
        "    model_handle: Path to a TensorFlow SavedModel directory.\n",
        "\n",
        "  Returns:\n",
        "    The model's 'serving_default' signature, callable for inference.\n",
        "  \"\"\"\n",
        "  print('loading model...')\n",
        "  print(model_handle)\n",
        "  saved_model = tf.saved_model.load(model_handle)\n",
        "  print('model loaded!')\n",
        "  return saved_model.signatures['serving_default']\n",
        "\n",
        "\n",
        "def perform_detection(model, image):\n",
        "  \"\"\"Runs a detection model on an image and returns its outputs as numpy.\n",
        "\n",
        "  Args:\n",
        "    model: A callable (e.g. a SavedModel serving signature) that maps an image\n",
        "      batch to a dict of output tensors.\n",
        "    image: The pre-processed image batch to pass to `model`.\n",
        "\n",
        "  Returns:\n",
        "    A dict with the same keys as the model output, in which every value has\n",
        "    been converted with `.numpy()`.\n",
        "  \"\"\"\n",
        "  outputs = model(image)\n",
        "  return {key: value.numpy() for key, value in outputs.items()}"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "t7d00cJH-68Z"
      },
      "outputs": [],
      "source": [
        "# 'material_model' output is both material and its sub type e.g. Plastics_PET.\n",
        "# 'material_form_model' outputs the form of an object e.g. can, bottle, etc.\n",
        "MODEL_WEIGHTS_DICT = {\n",
        "  'MODELS_WEIGHTS_RESNET_V1' : {\n",
        "      'material_url': (\n",
        "          'https://storage.googleapis.com/tf_model_garden/vision/'\n",
        "          'waste_identification_ml/two_model_strategy/material_resnet_v1.zip'\n",
        "      ),\n",
        "      'material_form_url': (\n",
        "          'https://storage.googleapis.com/tf_model_garden/vision/'\n",
        "          'waste_identification_ml/two_model_strategy/material_form_resnet_v1.zip'\n",
        "      ),\n",
        "  },\n",
        "  'MODELS_WEIGHTS_RESNET_V2': {\n",
        "      'material_url': (\n",
        "          'https://storage.googleapis.com/tf_model_garden/vision/'\n",
        "          'waste_identification_ml/two_model_strategy/material_resnet_v2.zip'\n",
        "      ),\n",
        "      'material_form_url': (\n",
        "          'https://storage.googleapis.com/tf_model_garden/vision/'\n",
        "          'waste_identification_ml/two_model_strategy/material_form_resnet_v2.zip'\n",
        "      ),\n",
        "  },\n",
        "  'MODELS_WEIGHTS_MOBILENET_V2': {\n",
        "      'material_url': (\n",
        "          'https://storage.googleapis.com/tf_model_garden/vision/'\n",
        "          'waste_identification_ml/two_model_strategy/material_mobilenet_v2.zip'\n",
        "      ),\n",
        "      'material_form_url': (\n",
        "          'https://storage.googleapis.com/tf_model_garden/vision/'\n",
        "          'waste_identification_ml/two_model_strategy/material_form_mobilenet_v2.zip'\n",
        "      ),\n",
        "  }\n",
        "}\n",
        "\n",
        "MODELS_RESNET_V1 = {\n",
        "'material_model' : 'material/material_resnet_v1/saved_model/',\n",
        "'material_form_model' : 'material_form/material_form_resnet_v1/saved_model/',\n",
        "}\n",
        "\n",
        "MODELS_RESNET_V2 = {\n",
        "'material_model' : 'material/material_resnet_v2/saved_model/',\n",
        "'material_form_model' : 'material_form/material_form_resnet_v2/saved_model/',\n",
        "}\n",
        "\n",
        "MODELS_MOBILENET_V2 = {\n",
        "'material_model' : 'material/material_mobilenet_v2/saved_model/',\n",
        "'material_form_model' : 'material_form/material_form_mobilenet_v2/saved_model/',\n",
        "}\n",
        "\n",
        "LABELS = {\n",
        "    'material_model': (\n",
        "        'models/official/projects/waste_identification_ml/pre_processing/'\n",
        "        'config/data/two_model_strategy_material.csv'\n",
        "    ),\n",
        "    'material_form_model': (\n",
        "        'models/official/projects/waste_identification_ml/pre_processing/'\n",
        "        'config/data/two_model_strategy_material_form.csv'\n",
        "    ),\n",
        "}\n",
        "\n",
        "# Path to a sample image stored in the repo.\n",
        "IMAGES_FOR_TEST = {\n",
        "    'Image1': (\n",
        "        'models/official/projects/waste_identification_ml/pre_processing/'\n",
        "        'config/sample_images/image_2.png'\n",
        "    )\n",
        "}"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "4XjfDEq--UlE"
      },
      "source": [
        "## Import pre-trained models."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "ZQ435YHN3Lr-"
      },
      "outputs": [],
      "source": [
        "selected_model = \"MODELS_WEIGHTS_MOBILENET_V2\" #@param [\"MODELS_WEIGHTS_RESNET_V1\", \"MODELS_WEIGHTS_RESNET_V2\", \"MODELS_WEIGHTS_MOBILENET_V2\"]\n",
        "\n",
        "# Saved-model directories that belong to each downloadable weight bundle.\n",
        "_model_paths_by_selection = {\n",
        "    \"MODELS_WEIGHTS_RESNET_V1\": MODELS_RESNET_V1,\n",
        "    \"MODELS_WEIGHTS_RESNET_V2\": MODELS_RESNET_V2,\n",
        "    \"MODELS_WEIGHTS_MOBILENET_V2\": MODELS_MOBILENET_V2,\n",
        "}\n",
        "if selected_model in _model_paths_by_selection:\n",
        "  ALL_MODELS = _model_paths_by_selection[selected_model]\n",
        "\n",
        "# Look up the download URLs of the two weight archives for the selection.\n",
        "selected_weights = MODEL_WEIGHTS_DICT[selected_model]\n",
        "url1 = selected_weights['material_url']\n",
        "url2 = selected_weights['material_form_url']\n",
        "\n",
        "# Fetch and unpack both archives with the helper script downloaded earlier.\n",
        "!python3 download_and_unzip_models.py $url1 $url2"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "W6mmyLsOJicF"
      },
      "source": [
        "## Load label map data"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "PM2A29OrJqaU"
      },
      "source": [
        "Label maps map index numbers to category names, so that when our convolutional network predicts 5, we know that this corresponds to airplane. Here we use internal utility functions, but anything that returns a dictionary mapping integers to appropriate string labels would be fine.\n",
        "\n",
        "For simplicity, we are going to load them from the same repository from which we loaded the Object Detection API code."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "5RUzrh0uegqt"
      },
      "outputs": [],
      "source": [
        "# The total number of predicted labels (category_indices) for a combined model = 741.\n",
        "category_indices, category_index = labels.load_labels(LABELS)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "vBm2aQzfHhId"
      },
      "outputs": [],
      "source": [
        "# Display labels only for 'material' model.\n",
        "# Total number of labels for 'material' model = 19.\n",
        "category_indices[0]"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "Eh_Ey6lXHs8m"
      },
      "outputs": [],
      "source": [
        "# Display labels only for 'material_form' model.\n",
        "# Total number of labels for 'material form' model = 39.\n",
        "category_indices[1]"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "PFczkMGBClZ4"
      },
      "source": [
        "## Load pre-trained weights for both models."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "5J6MgjOSC5JO"
      },
      "outputs": [],
      "source": [
        "# Loading both models.\n",
        "detection_fns = [load_model(model_path) for model_path in ALL_MODELS.values()]"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "VkdD8-QvGZ23"
      },
      "source": [
        "## Loading an image"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "BSXQF57FGba5"
      },
      "source": [
        "Let's try the model on a simple image. Here are some simple things to try out if you are curious:\n",
        "\n",
        "\n",
        "\n",
        "*   Try running inference on your own images, just upload them to colab and load the same way it's done in the cell below.\n",
        "*   Modify some of the input images and see if detection still works. Some simple things to try out here include flipping the image horizontally, or converting to grayscale (note that we still expect the input image to have 3 channels)."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "UVdlDwchGim_"
      },
      "source": [
        "Be careful when using images with an alpha channel: the model expects 3-channel images, and the alpha channel will count as a 4th."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "-7ZS7gHgGk9f"
      },
      "outputs": [],
      "source": [
        "selected_image = 'Image1'\n",
        "flip_image_horizontally = False #@param {type:\"boolean\"}\n",
        "convert_image_to_grayscale = False #@param {type:\"boolean\"}\n",
        "\n",
        "image_path = IMAGES_FOR_TEST[selected_image]\n",
        "image_np = load_image_into_numpy_array(image_path)\n",
        "\n",
        "# Flip horizontally\n",
        "if (flip_image_horizontally):\n",
        "  image_np[0] = np.fliplr(image_np[0]).copy()\n",
        "\n",
        "# Convert image to grayscale\n",
        "if (convert_image_to_grayscale):\n",
        "  image_np[0] = np.tile(\n",
        "    np.mean(image_np[0], 2, keepdims=True), (1, 1, 3)).astype(np.uint8)\n",
        "\n",
        "print('min:', np.min(image_np[0]), 'max:', np.max(image_np[0]))\n",
        "plt.figure(figsize=(24,32))\n",
        "plt.imshow(image_np[0])\n",
        "plt.show()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "ztNkdbIwGnoH"
      },
      "source": [
        "## Pre-processing an image"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "dCAhYIyzGUNj"
      },
      "outputs": [],
      "source": [
        "# Get the input image size from one of the instance segmentation models.\n",
        "height = detection_fns[0].structured_input_signature[1]['inputs'].shape[1]\n",
        "width = detection_fns[0].structured_input_signature[1]['inputs'].shape[2]\n",
        "input_size = (height, width)\n",
        "print(input_size)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "Uc_UWbRpGq1-"
      },
      "outputs": [],
      "source": [
        "image_np_cp = tf.image.resize(image_np[0], input_size, method=tf.image.ResizeMethod.AREA)\n",
        "image_np_cp = tf.cast(image_np_cp, tf.uint8)\n",
        "image_np = preprocessing.normalize_image(image_np_cp)\n",
        "image_np = tf.expand_dims(image_np, axis=0)\n",
        "image_np.get_shape()"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "wACzUTKWGxh4"
      },
      "outputs": [],
      "source": [
        "# Display pre-processed image.\n",
        "plt.figure(figsize=(24,32))\n",
        "plt.imshow(image_np[0])\n",
        "plt.show()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "H3r73X-FGzz-"
      },
      "source": [
        "## Doing the inference"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "SqW1z96LGzmZ"
      },
      "outputs": [],
      "source": [
        "# Run inference with both models on the pre-processed image.\n",
        "results = [perform_detection(model, image_np) for model in detection_fns]"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "332GbmRuG5A9"
      },
      "source": [
        "## Merge results, extract properties and detect colors"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "oxWFukUTG3d5"
      },
      "outputs": [],
      "source": [
        "use_generic_color = True #@param {type:\"boolean\"}\n",
        "\n",
        "SCORE_THRESH = 0.8\n",
        "\n",
        "no_detections_in_first = results[0]['num_detections'][0]\n",
        "no_detections_in_second = results[1]['num_detections'][0]\n",
        "\n",
        "if no_detections_in_first != 0 and no_detections_in_second != 0:\n",
        "  # Reframe the masks from the output of the model to its original size.\n",
        "  results = [postprocessing.reframing_masks(detection, height, width) for detection in results]\n",
        "\n",
        "  # Required to loop over the first max_detection values from both model outputs.\n",
        "  max_detection = max(no_detections_in_first, no_detections_in_second)\n",
        "\n",
        "  # This threshold will be used to eliminate all the detected objects whose area\n",
        "  # is greater than the 'area_threshold'.\n",
        "  # np.prod is used because its alias np.product was removed in NumPy 2.0.\n",
        "  area_threshold = 0.3 * np.prod(image_np_cp.shape[:2])\n",
        "\n",
        "  # Align similar masks from both the model outputs and merge all the properties\n",
        "  # into a single mask. Function will only compare first 'max_detection' objects.\n",
        "  # All the objects which have less than 'SCORE_THRESH' probability will be\n",
        "  # eliminated. All objects whose area is more than 'area_threshold' will be\n",
        "  # eliminated. 'category_indices' and 'category_index' are used to find the\n",
        "  # label from the combinations of labels from both individual models. The\n",
        "  # output should include masks appearing in either of the models if they\n",
        "  # qualify the criteria.\n",
        "  final_result = postprocessing.find_similar_masks(\n",
        "      results[0],\n",
        "      results[1],\n",
        "      max_detection,\n",
        "      SCORE_THRESH,\n",
        "      category_indices,\n",
        "      category_index,\n",
        "      area_threshold\n",
        "  )\n",
        "\n",
        "  # Convert normalized bounding box coordinates to pixel coordinates, keeping\n",
        "  # the [ymin, xmin, ymax, xmax] order of the model output.\n",
        "  transformed_boxes = [\n",
        "      [int(ymin * height), int(xmin * width),\n",
        "       int(ymax * height), int(xmax * width)]\n",
        "      for ymin, xmin, ymax, xmax in final_result['detection_boxes'][0]\n",
        "  ]\n",
        "\n",
        "  # Filtering duplicate bounding boxes.\n",
        "  filtered_boxes, index_to_delete = (\n",
        "    postprocessing.filter_bounding_boxes(transformed_boxes))\n",
        "\n",
        "  # Removing the corresponding values from all keys of a dictionary which\n",
        "  # corresponds to the duplicate bounding boxes.\n",
        "  final_result['num_detections'][0] -= len(index_to_delete)\n",
        "  final_result['detection_classes'] = np.delete(\n",
        "      final_result['detection_classes'], index_to_delete)\n",
        "  final_result['detection_scores'] = np.delete(\n",
        "      final_result['detection_scores'], index_to_delete, axis=1)\n",
        "  final_result['detection_boxes'] = np.delete(\n",
        "      final_result['detection_boxes'], index_to_delete, axis=1)\n",
        "  final_result['detection_classes_names'] = np.delete(\n",
        "      final_result['detection_classes_names'], index_to_delete)\n",
        "  final_result['detection_masks_reframed'] = np.delete(\n",
        "      final_result['detection_masks_reframed'], index_to_delete, axis=0)\n",
        "\n",
        "  if final_result['num_detections'][0]:\n",
        "\n",
        "    # Calculate properties of each object for object tracking purpose.\n",
        "    dfs, cropped_masks = (\n",
        "        color_and_property_extractor.extract_properties_and_object_masks(\n",
        "            final_result, height, width, image_np_cp))\n",
        "    features = pd.concat(dfs, ignore_index=True)\n",
        "    features['image_name'] = selected_image\n",
        "    features = features.rename(columns={\n",
        "        'centroid-0':'y',\n",
        "        'centroid-1':'x',\n",
        "        'bbox-0':'bbox_0',\n",
        "        'bbox-1':'bbox_1',\n",
        "        'bbox-2':'bbox_2',\n",
        "        'bbox-3':'bbox_3'\n",
        "    })\n",
        "\n",
        "    # Find color of each object.\n",
        "    dominant_colors = [*map(color_and_property_extractor.find_dominant_color, cropped_masks)]\n",
        "    specific_color_names = [*map(color_and_property_extractor.get_color_name, dominant_colors)]\n",
        "    generic_color_names = color_and_property_extractor.get_generic_color_name(dominant_colors)\n",
        "    # Honour the 'use_generic_color' form field, which was previously ignored.\n",
        "    features['color'] = generic_color_names if use_generic_color else specific_color_names\n",
        "else:\n",
        "  # Nothing to merge; the cells below need 'final_result' and 'features'.\n",
        "  print('At least one model returned no detections; skipping post-processing.')"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "7gAmUStY-Ch0"
      },
      "outputs": [],
      "source": [
        "# Showing all properties of 1st object in an image\n",
        "features.iloc[0]"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "bkN8UgZ0HcM-"
      },
      "source": [
        "## Visualization of masks"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "MErKSA5jHbEq"
      },
      "outputs": [],
      "source": [
        "%matplotlib inline\n",
        "image_new = image_np_cp.numpy().copy()\n",
        "\n",
        "if 'detection_masks_reframed' in final_result:\n",
        "  final_result['detection_masks_reframed'] = final_result['detection_masks_reframed'].astype(np.uint8)\n",
        "\n",
        "viz_utils.visualize_boxes_and_labels_on_image_array(\n",
        "      image_new,\n",
        "      final_result['detection_boxes'][0],\n",
        "      (final_result['detection_classes'] + 0).astype(int),\n",
        "      final_result['detection_scores'][0],\n",
        "      category_index=category_index,\n",
        "      use_normalized_coordinates=True,\n",
        "      max_boxes_to_draw=100,\n",
        "      min_score_thresh=0.6,\n",
        "      agnostic_mode=False,\n",
        "      instance_masks=final_result.get('detection_masks_reframed', None),\n",
        "      line_thickness=2)\n",
        "\n",
        "plt.figure(figsize=(20,10))\n",
        "plt.imshow(image_new)\n",
        "plt.show()"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "1z44wBf3Hovc"
      },
      "outputs": [],
      "source": [
        "# Visualize binary masks.\n",
        "mask = np.zeros_like(final_result['detection_masks_reframed'][0])\n",
        "for i in final_result['detection_masks_reframed']:\n",
        "  mask += i\n",
        "\n",
        "plt.figure(figsize=(20,10))\n",
        "plt.imshow(mask)\n",
        "plt.show()"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "oW6MDY03eIa4"
      },
      "outputs": [],
      "source": [
        "# Visualize color detection\n",
        "%matplotlib inline\n",
        "image_new = image_np_cp.numpy().copy()\n",
        "\n",
        "if 'detection_masks_reframed' in final_result:\n",
        "  final_result['detection_masks_reframed'] = final_result['detection_masks_reframed'].astype(np.uint8)\n",
        "\n",
        "color_labels = [f\"{specific_color_name}_{generic_color_names[i]}\" for i, specific_color_name in enumerate(specific_color_names)]\n",
        "detected_colors = np.unique(color_labels)\n",
        "color_inverse_index = { v: index[0]+1 for index, v in np.ndenumerate(detected_colors) }\n",
        "color_category = [color_inverse_index[color] for color in color_labels]\n",
        "color_index = labels.categories_dictionary(detected_colors)\n",
        "\n",
        "viz_utils.visualize_boxes_and_labels_on_image_array(\n",
        "      image_new,\n",
        "      final_result['detection_boxes'][0],\n",
        "      (np.array(color_category)).astype(int),\n",
        "      final_result['detection_scores'][0],\n",
        "      category_index=color_index,\n",
        "      use_normalized_coordinates=True,\n",
        "      max_boxes_to_draw=100,\n",
        "      min_score_thresh=0.6,\n",
        "      agnostic_mode=False,\n",
        "      instance_masks=final_result.get('detection_masks_reframed', None),\n",
        "      line_thickness=2,\n",
        "      skip_scores=True,\n",
        "      mask_alpha=0)\n",
        "\n",
        "plt.figure(figsize=(20,10))\n",
        "plt.imshow(image_new)\n",
        "plt.show()"
      ]
    }
  ],
  "metadata": {
    "accelerator": "GPU",
    "colab": {
      "gpuType": "T4",
      "provenance": []
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    },
    "language_info": {
      "name": "python"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}
